{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "rational-latitude",
   "metadata": {},
   "source": [
    "# PyTorch Distributed and data-parallel training\n",
    "In this notebook, we'll be overview the distributed part of the PyTorch library. We will also see a couple of examples of distributed training using available wrappers.\n",
    "\n",
    "To find out about the communication pattern of your GPUs, you can use the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa782c3d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!nvidia-smi topo -m"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "522e105b",
   "metadata": {},
   "source": [
    "Let's import all required libraries and define a function which will create the process group. There are [three](https://pytorch.org/docs/stable/distributed.html#backends-that-come-with-pytorch) communication backends in PyTorch: as a simple rule, use GLOO for CPU communication and NCCL for communication between NVIDIA GPUs."
   ]
  },
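  {
   "cell_type": "markdown",
   "id": "backend-check-note",
   "metadata": {},
   "source": [
    "Which backends are usable depends on how your PyTorch build was compiled. The cell below is a minimal sketch that reports availability and applies the rule of thumb above; `pick_backend` is a hypothetical helper written for this notebook, not a PyTorch function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "backend-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.distributed as dist\n",
    "\n",
    "# Report which built-in backends this PyTorch build supports\n",
    "for name, check in [('gloo', dist.is_gloo_available),\n",
    "                    ('mpi', dist.is_mpi_available),\n",
    "                    ('nccl', dist.is_nccl_available)]:\n",
    "    print(f'{name}: {check()}')\n",
    "\n",
    "# Hypothetical helper: NCCL when CUDA GPUs are usable, Gloo otherwise\n",
    "def pick_backend():\n",
    "    if torch.cuda.is_available() and dist.is_nccl_available():\n",
    "        return 'nccl'\n",
    "    return 'gloo'\n",
    "\n",
    "print('suggested backend:', pick_backend())"
   ]
  },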
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "divided-woman",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from torch.multiprocessing import Process\n",
    "import random\n",
    "\n",
    "def init_process(rank, size, fn, master_port, backend='gloo'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n",
    "    os.environ['MASTER_PORT'] = str(master_port)\n",
    "    dist.init_process_group(backend, rank=rank, world_size=size)\n",
    "    fn(rank, size)\n",
    "    \n",
    "torch.set_num_threads(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "infinite-chair",
   "metadata": {},
   "source": [
    "First, we'll run a very simple function with torch.distributed.barrier. The cell below prints in the first process and then prints in all other processes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hungry-chest",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def run(rank, size):\n",
    "    \"\"\" Distributed function to be implemented later. \"\"\"\n",
    "    if rank!=0:\n",
    "        dist.barrier()\n",
    "    print(f'Started {rank}',flush=True)\n",
    "    if rank==0:\n",
    "        dist.barrier()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    size = 4\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "living-intermediate",
   "metadata": {},
   "source": [
    "Let's implement a classical ping-pong application with this paradigm. We have two processes, and the goal is to have P1 output 'ping' and P2 output 'pong' without any race conditions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "closed-nashville",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def run_pingpong(rank, size, num_iter=10):\n",
    "    \"\"\" Distributed function to be implemented later. \"\"\"\n",
    "    \n",
    "    \n",
    "    for _ in range(num_iter):\n",
    "        pass\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    size = 2\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_pingpong, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "therapeutic-foster",
   "metadata": {},
   "source": [
    "# Point-to-point communication\n",
    "The functions below show that it's possible to send data from one process to another with `torch.distributed.send/torch.distributed.recv`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "varying-nightmare",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\"\"\"Blocking point-to-point communication.\"\"\"\n",
    "\n",
    "def run_sendrecv(rank, size):\n",
    "    tensor = torch.zeros(1)+int(rank==0)\n",
    "    print('Rank ', rank, ' has data ', tensor[0], flush=True)\n",
    "    if rank == 0:\n",
    "        # Send the tensor to process 1\n",
    "        dist.send(tensor=tensor, dst=1)\n",
    "    else:\n",
    "        # Receive tensor from process 0\n",
    "        dist.recv(tensor=tensor, src=0)\n",
    "    print('Rank ', rank, ' has data ', tensor[0], flush=True)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    size = 2\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_sendrecv, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a625f88a",
   "metadata": {},
   "source": [
    "Also, these functions have an immediate (asynchronous) version:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "statistical-serve",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\"\"\"Non-blocking point-to-point communication.\"\"\"\n",
    "import time\n",
    "\n",
    "def run_isendrecv(rank, size):\n",
    "    tensor = torch.zeros(1)\n",
    "    req = None\n",
    "    if rank == 0:\n",
    "        tensor += 1\n",
    "        # Send the tensor to process 1\n",
    "        req = dist.isend(tensor=tensor, dst=1)\n",
    "        print('Rank 0 started sending')\n",
    "    else:\n",
    "        # Receive tensor from process 0\n",
    "        req = dist.irecv(tensor=tensor, src=0)\n",
    "        print('Rank 1 started receiving')\n",
    "        \n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    req.wait()\n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    size = 2\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_isendrecv, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "830a75ce",
   "metadata": {},
   "source": [
    "Adding an artificial delay shows that the communication is asynchronous:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc124708",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "def run_isendrecv(rank, size):\n",
    "    tensor = torch.zeros(1)\n",
    "    req = None\n",
    "    if rank == 0:\n",
    "        tensor += 1\n",
    "        # Send the tensor to process 1\n",
    "        time.sleep(5)\n",
    "        req = dist.isend(tensor=tensor, dst=1)\n",
    "        print('Rank 0 started sending')\n",
    "    else:\n",
    "        # Receive tensor from process 0\n",
    "        req = dist.irecv(tensor=tensor, src=0)\n",
    "        print('Rank 1 started receiving')\n",
    "        \n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    req.wait()\n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    size = 2\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_isendrecv, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "generous-sacramento",
   "metadata": {},
   "source": [
    "# Collective communication and All-Reduce\n",
    "Now, let's run a simple All-Reduce example which computes the sum across all workers. We'll be running the code with the `!python` magic to avoid issues caused by interaction of Jupyter and multiprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "unsigned-conservative",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%writefile run_allreduce.py\n",
    "#!/usr/bin/env python\n",
    "import os\n",
    "from functools import partial\n",
    "\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from torch.multiprocessing import Process\n",
    "\n",
    "def run_allreduce(rank, size):\n",
    "    tensor = torch.ones(1)\n",
    "    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    \n",
    "def init_process(rank, size, fn, backend='gloo'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n",
    "    os.environ['MASTER_PORT'] = '29500'\n",
    "    dist.init_process_group(backend, rank=rank, world_size=size)\n",
    "    fn(rank, size)\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    size = 10\n",
    "    processes = []\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_allreduce))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "moderate-latex",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!python run_allreduce.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83b8cf8a",
   "metadata": {},
   "source": [
    "The same thing can be done with a simpler [torch.multiprocessing.spawn](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "chronic-tokyo",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%writefile run_allreduce_spawn.py\n",
    "#!/usr/bin/env python\n",
    "import os\n",
    "from functools import partial\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "\n",
    "def run_allreduce(rank, size):\n",
    "    tensor = torch.ones(1)\n",
    "    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n",
    "    print('Rank ', rank, ' has data ', tensor[0])\n",
    "    \n",
    "def init_process(rank, size, fn, backend='gloo'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n",
    "    os.environ['MASTER_PORT'] = '29500'\n",
    "    dist.init_process_group(backend, rank=rank, world_size=size)\n",
    "    fn(rank, size)\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    size = 10\n",
    "\n",
    "    fn = partial(init_process, size=size, fn=run_allreduce, backend='gloo')\n",
    "    torch.multiprocessing.spawn(fn, nprocs=size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "extreme-trigger",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!./run_allreduce_spawn.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d735d0df",
   "metadata": {},
   "source": [
    "Let's write our own Butterfly All-Reduce. First, we start with creating 5 random vectors and getting the \"true\" average, just for comparison:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffc59d3c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "size = 5\n",
    "tensors = []\n",
    "\n",
    "for i in range(size):\n",
    "    torch.manual_seed(i)\n",
    "    cur_tensor = torch.randn((size,), dtype=torch.float)\n",
    "    print(cur_tensor)\n",
    "    tensors.append(cur_tensor)\n",
    "    \n",
    "print(\"result\", torch.stack(tensors).mean(0))"
   ]
  },
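  {
   "cell_type": "markdown",
   "id": "butterfly-sim-note",
   "metadata": {},
   "source": [
    "Before going distributed, it can help to simulate the algorithm in a single process. The sketch below assumes the two-phase scheme used in this notebook: worker `j` first averages element `j` of every vector, then every worker collects the averaged elements from their owners. The names `sim_size`, `sim_tensors`, `chunk_means` and `results` are illustrative, and no communication actually happens here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "butterfly-sim-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "sim_size = 5\n",
    "sim_tensors = []\n",
    "for i in range(sim_size):\n",
    "    torch.manual_seed(i)\n",
    "    sim_tensors.append(torch.randn((sim_size,), dtype=torch.float))\n",
    "\n",
    "# Phase 1: worker j gathers element j from every worker and averages it\n",
    "chunk_means = [torch.stack([t[j] for t in sim_tensors]).mean() for j in range(sim_size)]\n",
    "\n",
    "# Phase 2: every worker receives the averaged element from each owner\n",
    "results = [torch.stack(chunk_means) for _ in range(sim_size)]\n",
    "\n",
    "# Every simulated worker should end up with the plain element-wise average\n",
    "expected = torch.stack(sim_tensors).mean(0)\n",
    "assert all(torch.allclose(r, expected) for r in results)\n",
    "print(results[0])"
   ]
  },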
  {
   "cell_type": "markdown",
   "id": "ad9ac2f8",
   "metadata": {},
   "source": [
    "Now, let's create a custom implementation below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56a97730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%writefile custom_allreduce.py\n",
    "import os\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from torch.multiprocessing import Process\n",
    "import random\n",
    "\n",
    "def init_process(rank, size, fn, master_port, backend='gloo'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n",
    "    os.environ['MASTER_PORT'] = str(master_port)\n",
    "    dist.init_process_group(backend, rank=rank, world_size=size)\n",
    "    fn(rank, size)\n",
    "\n",
    "def butterfly_allreduce(send, rank, size):\n",
    "    \"\"\"\n",
    "    Performs Butterfly All-Reduce over the process group.\n",
    "    Args:\n",
    "        send: torch.Tensor to be averaged with other processes.\n",
    "        rank: Current process rank (in a range from 0 to size)\n",
    "        size: Number of workers\n",
    "    \"\"\"\n",
    "    \n",
    "    buffer_for_chunk = torch.empty((size,), dtype=torch.float)\n",
    "    \n",
    "    send_futures = []\n",
    "    \n",
    "    for i, elem in enumerate(send):\n",
    "        if i!=rank:\n",
    "            send_futures.append(dist.isend(elem, i))\n",
    "            \n",
    "    recv_futures = []\n",
    "    \n",
    "    for i, elem in enumerate(buffer_for_chunk):\n",
    "        if i!=rank:\n",
    "            recv_futures.append(dist.irecv(elem, i))\n",
    "        else:\n",
    "            elem.copy_(send[i])\n",
    "            \n",
    "    for future in recv_futures:\n",
    "        future.wait()\n",
    "        \n",
    "    # compute the average\n",
    "    torch.mean(buffer_for_chunk, dim=0, out=send[rank])\n",
    "    \n",
    "    for i in range(size):\n",
    "        if i!=rank:\n",
    "            send_futures.append(dist.isend(send[rank], i))\n",
    "            \n",
    "    recv_futures = []\n",
    "    \n",
    "    for i, elem in enumerate(send):\n",
    "        if i!=rank:\n",
    "            recv_futures.append(dist.irecv(elem, i))\n",
    "    \n",
    "    for future in recv_futures:\n",
    "        future.wait()\n",
    "    for future in send_futures:\n",
    "        future.wait()\n",
    "            \n",
    "\n",
    "def run_allreduce(rank, size):\n",
    "    \"\"\" Simple point-to-point communication. \"\"\"\n",
    "    torch.manual_seed(rank)\n",
    "    tensor = torch.randn((size,), dtype=torch.float)\n",
    "    print('Rank ', rank, ' has data ', tensor)\n",
    "    butterfly_allreduce(tensor, rank, size)\n",
    "    print('Rank ', rank, ' has data ', tensor)\n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    size = 5\n",
    "    processes = []\n",
    "    port = random.randint(25000, 30000)\n",
    "    for rank in range(size):\n",
    "        p = Process(target=init_process, args=(rank, size, run_allreduce, port))\n",
    "        p.start()\n",
    "        processes.append(p)\n",
    "\n",
    "    for p in processes:\n",
    "        p.join()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9547e670",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!python custom_allreduce.py"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "governmental-lesson",
   "metadata": {},
   "source": [
    "# Distributed training\n",
    "\n",
    "Now that we have this simple implementation of AllReduce, we can run multi-process distributed training. For now, let's use the model and the dataset from the official MNIST [example](https://github.com/pytorch/examples/blob/master/mnist/main.py), as well as the [torchrun](https://pytorch.org/docs/stable/elastic/run.html?highlight=torchrun) command used to manage processes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "instructional-dictionary",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%writefile custom_allreduce_training.py\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "\n",
    "def init_process(local_rank, fn, backend='nccl'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    dist.init_process_group(backend, rank=local_rank)\n",
    "    size = dist.get_world_size()\n",
    "    fn(local_rank, size)\n",
    "\n",
    "\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(4608, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        output = self.fc2(x)\n",
    "        return output\n",
    "\n",
    "\n",
    "def average_gradients(model):\n",
    "    size = float(dist.get_world_size())\n",
    "    for param in model.parameters():\n",
    "        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)\n",
    "        param.grad.data /= size\n",
    "\n",
    "\n",
    "def run_training(rank, size):\n",
    "    torch.manual_seed(1234)\n",
    "    dataset = MNIST('./mnist', download=True, transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ]))\n",
    "    loader = DataLoader(dataset, sampler=DistributedSampler(dataset, size, rank), batch_size=16)\n",
    "    model = Net()\n",
    "    device = torch.device('cpu')\n",
    "    model.to(device)\n",
    "    optimizer = torch.optim.SGD(model.parameters(),\n",
    "                                lr=0.01, momentum=0.5)\n",
    "\n",
    "    num_batches = len(loader)\n",
    "    steps = 0\n",
    "    epoch_loss = 0\n",
    "    for data, target in loader:\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = torch.nn.functional.cross_entropy(output, target)\n",
    "        epoch_loss += loss.item()\n",
    "        loss.backward()\n",
    "        average_gradients(model)\n",
    "        optimizer.step()\n",
    "        steps += 1\n",
    "        if True:\n",
    "            print(f'Rank {dist.get_rank()}, loss: {epoch_loss / num_batches}')\n",
    "            epoch_loss = 0\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    local_rank = int(os.environ[\"LOCAL_RANK\"])\n",
    "    init_process(local_rank, fn=run_training, backend='gloo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "compound-trustee",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!torchrun --nproc_per_node 2 custom_allreduce_training.py"
   ]
  },
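  {
   "cell_type": "markdown",
   "id": "torchrun-env-note",
   "metadata": {},
   "source": [
    "`torchrun` hands each worker its coordinates through environment variables, which is why the script above can call `init_process_group` without a world size or a master address. The following sketch prints those variables and then initializes a Gloo process group from them; the file name `torchrun_env.py` is an arbitrary choice for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "torchrun-env-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile torchrun_env.py\n",
    "import os\n",
    "\n",
    "import torch.distributed as dist\n",
    "\n",
    "# torchrun sets these variables for every worker it launches\n",
    "for key in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE', 'MASTER_ADDR', 'MASTER_PORT'):\n",
    "    print(f'{key}={os.environ.get(key)}', flush=True)\n",
    "\n",
    "# With the default env:// init method, rank and world size are read from the environment\n",
    "dist.init_process_group('gloo')\n",
    "print(f'rank {dist.get_rank()} of {dist.get_world_size()}', flush=True)\n",
    "dist.destroy_process_group()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "torchrun-env-run",
   "metadata": {},
   "outputs": [],
   "source": [
    "!torchrun --nproc_per_node 2 torchrun_env.py"
   ]
  },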
  {
   "cell_type": "markdown",
   "id": "c69ba92a",
   "metadata": {},
   "source": [
    "Now let's use the standard [DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) wrapper (which you should probably use in real-world training anyway):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304ae9a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%writefile ddp_example.py\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.nn.parallel import DistributedDataParallel\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "\n",
    "def init_process(local_rank, fn, backend='nccl'):\n",
    "    \"\"\" Initialize the distributed environment. \"\"\"\n",
    "    dist.init_process_group(backend, rank=local_rank)\n",
    "    size = dist.get_world_size()\n",
    "    fn(local_rank, size)\n",
    "\n",
    "\n",
    "torch.set_num_threads(1)\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 32, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(4608, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        output = self.fc2(x)\n",
    "        return output\n",
    "\n",
    "def run_training(rank, size):\n",
    "    torch.manual_seed(1234)\n",
    "    dataset = MNIST('./mnist', download=True, transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ]))\n",
    "    loader = DataLoader(dataset,\n",
    "                        sampler=DistributedSampler(dataset, size, rank),\n",
    "                        batch_size=16)\n",
    "    model = Net()\n",
    "    device = torch.device('cuda', rank)\n",
    "    model.to(device)\n",
    "    \n",
    "    model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n",
    "\n",
    "    num_batches = len(loader)\n",
    "    steps = 0\n",
    "    epoch_loss = 0\n",
    "    for data, target in loader:\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = torch.nn.functional.cross_entropy(output, target)\n",
    "        epoch_loss += loss.item()\n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        steps += 1\n",
    "        if True:\n",
    "            print(f'Rank {dist.get_rank()}, loss: {epoch_loss / num_batches}')\n",
    "            epoch_loss = 0\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    local_rank = int(os.environ[\"LOCAL_RANK\"])\n",
    "    init_process(local_rank, fn=run_training, backend='gloo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03925f20",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "!torchrun --nproc_per_node 2 ddp_example.py"
   ]
  },
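  {
   "cell_type": "markdown",
   "id": "sampler-shards-note",
   "metadata": {},
   "source": [
    "Both training scripts rely on `DistributedSampler` to give every rank its own shard of the dataset. Here is a small sketch of that split, assuming 2 workers and a toy 10-element dataset (`toy_dataset` and `num_workers` are illustrative names). Shuffling is disabled so that the shards are easy to read: rank 0 gets the even indices and rank 1 the odd ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sampler-shards-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "toy_dataset = list(range(10))  # stand-in for a real dataset; only its length matters here\n",
    "num_workers = 2\n",
    "\n",
    "for r in range(num_workers):\n",
    "    # Passing num_replicas and rank explicitly means no process group is required\n",
    "    sampler = DistributedSampler(toy_dataset, num_replicas=num_workers, rank=r, shuffle=False)\n",
    "    print(f'rank {r} reads indices {list(sampler)}')"
   ]
  },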
  {
   "cell_type": "markdown",
   "id": "acad0a5f",
   "metadata": {},
   "source": [
    "That's it for today! For the homework this week, see the `homework` folder."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
